#include <assert.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <errno.h>
#include <unistd.h>
#include <netdb.h>
#include <fcntl.h>
#include "utils.hpp"

// Create a UDP socket and bind it to `port` on all local IPv4 addresses.
//
// SO_REUSEADDR is requested (best effort) so the port can be shared with
// other listeners on the same port (e.g. 5353).
//
// Returns the socket descriptor on success, or -1 on failure after printing
// a diagnostic to stderr. No descriptor is leaked on failure.
int create_and_bind_udp(uint16_t port) {
  const int s = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
  if (s < 0) {
    fprintf(stderr, "error creating socket\n");
    return -1;
  }

  sockaddr_in servaddr;
  memset(&servaddr, 0, sizeof(servaddr));
  servaddr.sin_family      = AF_INET;
  servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
  servaddr.sin_port        = htons(port);

  // Other programs are probably also listening on port 5353, so we
  // need to specify that we are happy to share. This stays best effort:
  // a failure is reported but not fatal, because bind() may still succeed
  // when nobody else holds the port.
  const int trueValue = 1;
  if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &trueValue, sizeof(trueValue)) < 0) {
    const int err = errno;
    fprintf(stderr, "setsockopt(SO_REUSEADDR) failed s=%d err=%d %s\n",
            s, err, strerror(err));
  }

  if (bind(s, reinterpret_cast<sockaddr*>(&servaddr), sizeof(servaddr)) < 0) {
    const int err = errno;
    fprintf(stderr, "bind failed s=%d err=%d %s\n", s, err, strerror(err));
    close(s);  // Don't leak the descriptor on failure.
    return -1;
  }

  return s;
}

// Create a TCP socket bound to `port` on all local IPv4 addresses and start
// listening on it.
//
// SO_REUSEADDR is set so the port can be reused as soon as a previous
// instance of the program terminates.
//
// Returns the listening socket on success, or -1 on failure after printing
// a diagnostic to stdout. The socket is closed on every failure path, so no
// descriptor is leaked.
int create_bind_and_listen_tcp(uint16_t port) {
  // Create a socket for listening on the port.
  const int sock = socket(PF_INET, SOCK_STREAM, 0);
  if (sock < 0) {
    printf("Failed to create socket. Try running with sudo.\n");
    return -1;
  }

  // Allow the port to be reused as soon as the program terminates.
  int one = 1;
  const int r0 =
    setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
  if (r0 < 0) {
    printf("Failed to set SO_REUSEADDR.\n");
    close(sock);
    return -1;
  }

  // Bind the port. sin_family must be set explicitly: leaving it zero
  // (AF_UNSPEC) for an AF_INET socket is not portable.
  sockaddr_in addr;
  memset(&addr, 0, sizeof(addr));
  addr.sin_family = AF_INET;
  addr.sin_port = htons(port);
  addr.sin_addr.s_addr = htonl(INADDR_ANY);

  if (bind(sock, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) {
    const int err = errno;  // Capture before close() can overwrite errno.
    printf(
      "Error binding TCP socket to port. Try running with sudo.\nerrno = %d %s\n",
      err, strerror(err)
    );
    close(sock);
    return -1;
  }

  // Start listening.
  if (listen(sock, SOMAXCONN) < 0) {
    printf("listen failed.\n");
    close(sock);
    return -1;
  }

  return sock;
}

// Print `peer_addr` to stdout as "host:port", with no trailing newline.
// The port is always numeric. The host is formatted by getnameinfo()
// without NI_NUMERICHOST, so it may be a resolved name rather than an IP.
// If getnameinfo() fails, its error string is printed instead.
void print_addr(const sockaddr* peer_addr, socklen_t peer_addr_len) {
  char hostbuf[NI_MAXHOST];
  char servbuf[NI_MAXSERV];

  const int rc = getnameinfo(peer_addr, peer_addr_len,
                             hostbuf, sizeof(hostbuf),
                             servbuf, sizeof(servbuf),
                             NI_NUMERICSERV);
  if (rc != 0) {
    printf("getnameinfo:%s", gai_strerror(rc));
    return;
  }
  printf("%s:%s", hostbuf, servbuf);
}

// Register this handler's socket with the epoll instance `epollfd`.
//
// The socket is added edge-triggered (EPOLLET) for both readability and
// writability. `this` is stored as the event's user pointer so the main
// loop can dispatch events back to this handler.
//
// Returns 0 on success, or -1 if epoll_ctl fails (a diagnostic is printed).
int EPollHandlerInterface::epoll_add(const int epollfd) {
  printf("epoll_add socket %d handler=%p\n", sock_, static_cast<void*>(this));
  epoll_event ev;
  memset(&ev, 0, sizeof(ev));  // Don't hand the kernel uninitialized bytes.
  ev.events = EPOLLIN | EPOLLOUT | EPOLLET;
  ev.data.ptr = static_cast<void*>(this);
  if (epoll_ctl(epollfd, EPOLL_CTL_ADD, sock_, &ev) == -1) {
    const int err = errno;
    printf("EPOLL_CTL_ADD failed: %s\n", strerror(err));
    return -1;
  }
  return 0;
}

// Base-class destructor shared by all epoll handlers: logs and closes the
// handler's socket. Deleting a handler (as epoll_main_loop does when it
// shuts one down) therefore also closes its file descriptor.
EPollHandlerInterface::~EPollHandlerInterface() {
  printf("closing socket %d handler=%p\n", sock_, this);
  close(sock_);
}

// Send one datagram `buf[0..buflen)` to `dest_addr` on this UDP socket,
// passing MSG_NOSIGNAL. Returns sendto()'s result unchanged: the bytes sent,
// or -1 with errno set.
ssize_t EpollRecvHandlerUDP::replyto(
  const char* buf, size_t buflen,
  const sockaddr *dest_addr, socklen_t addrlen
) {
  // TODO: implement buffering for UDP sends, similar
  // to how it is done in EpollRecvHandlerTCP.
  const int send_flags = MSG_NOSIGNAL;
  const ssize_t sent =
    sendto(sock_, buf, buflen, send_flags, dest_addr, addrlen);
  return sent;
}

// Wrap the UDP socket `sock` in a new EpollRecvHandlerUDP that forwards
// datagrams to `handler`, and register it with the epoll instance `epollfd`.
//
// The socket is made non-blocking first, because process_read() drains it
// in a loop until recvfrom() fails.
//
// Ownership:
// - On success, the new wrapper owns `sock` and `handler`; the epoll main
//   loop deletes it when it shuts the handler down.
// - If epoll registration fails, the wrapper is deleted, which also deletes
//   `handler` and closes `sock`.
// - On the earlier validation or fcntl failures, neither is released and
//   the caller still owns both.
//
// Returns 0 on success, -1 on failure.
int EpollRecvHandlerUDP::build(
  const int epollfd, const int sock, RecvHandlerUDP* handler
) {
  if (sock < 0) {
    return -1;
  }

  if (!handler) {
    return -1;
  }

  // Make sure that the socket is non-blocking. Check F_GETFL on its own:
  // OR-ing O_NONBLOCK into its -1 error value would request every flag.
  const int flags = fcntl(sock, F_GETFL, 0);
  if (flags < 0 || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
    printf("set non-blocking failed\n");
    return -1;
  }

  // Plain `new` reports failure by throwing, never by returning null.
  EpollRecvHandlerUDP *h = new EpollRecvHandlerUDP(sock, handler);

  if (h->epoll_add(epollfd) < 0) {
    delete h;  // Also deletes `handler` and closes `sock`.
    return -1;
  }
  return 0;
}

// Releases the owned datagram handler. The socket itself is closed
// afterwards by the base-class destructor.
EpollRecvHandlerUDP:: ~EpollRecvHandlerUDP() {
  printf("~EpollRecvHandlerUDP\n");
  delete handler_;
}

// Drain every pending datagram from the UDP socket and pass each one, with
// its sender's address, to `handler_`.
//
// Returns 0 as soon as recvfrom() fails (e.g. EAGAIN when nothing is left to
// read). If the handler returns a negative value, that value is returned,
// which makes the main loop delete this handler and close the socket.
int EpollRecvHandlerUDP::process_read(const epoll_event *) {
  for (;;) {
    uint8_t datagram[4096];
    sockaddr_storage from;
    socklen_t fromlen = sizeof(sockaddr_storage);

    const ssize_t n = recvfrom(sock_, datagram, sizeof(datagram), 0,
                               reinterpret_cast<sockaddr*>(&from), &fromlen);
    if (n < 0) {
      return 0;
    }

    const int rc = handler_->receive(
      datagram, n, *this, reinterpret_cast<const sockaddr*>(&from), fromlen);
    if (rc < 0) {
      return rc;
    }
  }
}

// Wrap a connected TCP socket. `handler` is owned by this object and is
// deleted in the destructor. Both buffers start out empty; accept()
// allocates the first receive buffer.
EpollRecvHandlerTCP::EpollRecvHandlerTCP(
  const int sock, RecvHandlerTCP* handler
)
  : EPollHandlerInterface(sock)
  , handler_(handler)
  , recvbuf_(nullptr)
  , received_(0)
  , remaining_(0)
  , sendbuf_(nullptr)
  , sendbufsize_(0)
  , sendbufpos_(0)
{}

// Frees any unsent output, the current receive buffer and the owned
// protocol handler. The base-class destructor then closes the socket.
EpollRecvHandlerTCP::~EpollRecvHandlerTCP() {
  printf("~EpollRecvHandlerTCP\n");
  delete[] sendbuf_;
  delete[] recvbuf_;
  delete handler_;
}

int EpollRecvHandlerTCP::accept() {
  ssize_t remaining = handler_->accept(*this);
  if (remaining < 0) {
    return -1;
  }
  remaining_ = static_cast<size_t>(remaining);
  recvbuf_ = new uint8_t[remaining_];
  if (!recvbuf_) {
    return -1;
  }
  return 0;
}

// Queue `buf[0..buflen)` for sending on this TCP connection.
//
// If earlier output is still waiting in `sendbuf_`, the new bytes are
// appended behind it to preserve ordering. process_write() flushes the
// buffer when epoll reports EPOLLOUT. Otherwise the data is first sent
// directly, and only the part the kernel did not accept is buffered.
//
// Returns `buflen` (the number of the caller's bytes that were accepted,
// whether sent or buffered). Returns -1 on a send error other than
// EAGAIN/EWOULDBLOCK, with errno left as set by send().
ssize_t EpollRecvHandlerTCP::reply(const char* buf, size_t buflen) {
  if (sendbuf_) {
    // Data is already pending: compact away the part that has been sent
    // and append the new bytes. (`new` throws rather than returning null.)
    const size_t oldsize = sendbufsize_ - sendbufpos_;
    const size_t newsize = oldsize + buflen;
    uint8_t* newbuf = new uint8_t[newsize];
    memcpy(newbuf, sendbuf_ + sendbufpos_, oldsize);
    memcpy(newbuf + oldsize, buf, buflen);
    delete[] sendbuf_;
    sendbuf_ = newbuf;
    sendbufsize_ = newsize;
    sendbufpos_ = 0;
    // Match the direct-send path: report the caller's byte count, not the
    // total amount now pending.
    return buflen;
  }

  // First try to send the message directly. If that's unsuccessful,
  // or we only manage to send part of the message, then we'll copy
  // the rest of the message into `sendbuf_` and wait for an EPOLLOUT
  // notification.
  ssize_t wr = send(sock_, buf, buflen, MSG_NOSIGNAL);
  if (wr < 0) {
    const int err = errno;
    if (err != EAGAIN && err != EWOULDBLOCK) {
      return -1;
    }
    wr = 0;  // Socket buffer full: buffer everything.
  }

  const size_t sent = static_cast<size_t>(wr);
  if (sent == buflen) {
    // Successfully sent all the data.
    return buflen;
  }

  // Copy the remaining bytes to sendbuf_.
  const size_t rest = buflen - sent;
  sendbuf_ = new uint8_t[rest];
  memcpy(sendbuf_, buf + sent, rest);
  sendbufsize_ = rest;
  sendbufpos_ = 0;
  return buflen;
}

// Wrap the connected TCP socket `sock` in a new EpollRecvHandlerTCP driven
// by `handler`, register it with `epollfd`, and let the handler start the
// conversation via accept().
//
// The socket is made non-blocking first, because the read and write paths
// loop until EAGAIN/EWOULDBLOCK.
//
// Ownership:
// - On success, the new wrapper owns `sock` and `handler`; the epoll main
//   loop deletes it when the connection is shut down.
// - If epoll registration or the handler's accept() fails, the wrapper is
//   deleted, which also deletes `handler` and closes `sock`.
// - On the earlier validation or fcntl failures, neither is released and
//   the caller still owns both.
//
// Returns 0 on success, -1 on failure.
int EpollRecvHandlerTCP::build(
  const int epollfd, const int sock, RecvHandlerTCP* handler
) {
  if (sock < 0) {
    return -1;
  }

  if (!handler) {
    return -1;
  }

  // Make sure that the socket is non-blocking. Check F_GETFL on its own:
  // OR-ing O_NONBLOCK into its -1 error value would request every flag.
  const int flags = fcntl(sock, F_GETFL, 0);
  if (flags < 0 || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
    printf("set non-blocking failed\n");
    return -1;
  }

  // Plain `new` reports failure by throwing, never by returning null.
  EpollRecvHandlerTCP *h = new EpollRecvHandlerTCP(sock, handler);

  if (h->epoll_add(epollfd) < 0) {
    delete h;  // Also deletes `handler` and closes `sock`.
    return -1;
  }

  // The handler may reply immediately and tells us how much to read first;
  // a refusal tears the connection down.
  if (h->accept() < 0) {
    delete h;  // Also deletes `handler` and closes `sock`.
    return -1;
  }

  return 0;
}

// Read as much of the current inbound message as is available.
//
// The protocol handler says how many bytes it wants next (`remaining_`),
// via accept() and via each receive() return value. Bytes are collected in
// `recvbuf_` until that many have arrived, then handed to
// handler_->receive(). Its return value means:
// - 0: no more input. The read side is shut down, but pending output may
//   still be flushed.
// - negative: error.
//
// The socket is registered edge-triggered, so this keeps reading until one
// of three things happens: recv() would block, the handler stops asking
// for input, or an error occurs.
//
// Returns 0 to keep the connection, or -1 to ask the main loop to close it.
int EpollRecvHandlerTCP::process_read(const epoll_event *) {
  // Keep reading from the socket until there's no more data.
  while (remaining_ > 0) {
    // A connected TCP socket has a fixed peer, so the sender address that
    // recvfrom() would report is not needed.
    const ssize_t recvsize =
      recv(sock_, recvbuf_ + received_, remaining_, 0);

    if (recvsize == 0) {
      // If we received zero bytes, then it means that our peer closed
      // the connection. Since remaining_ > 0, that means that something
      // has gone wrong.
      printf("Our peer closed the TCP socket unexpectedly: %d.\n", sock_);
      handler_->disconnect();
      return -1;
    }

    if (recvsize < 0) {
      const int err = errno;
      if (err == EAGAIN || err == EWOULDBLOCK) {
        // Need to wait for more input. (We will get a notification from
        // epoll when that happens.)
        return 0;
      }
      printf("TCP read error. sock=%d err=%s.\n", sock_, strerror(err));
      handler_->disconnect();
      return -1;
    }

    // %zd is the conversion for ssize_t; %ld is only right where ssize_t
    // happens to be long.
    printf("TCP socket %d (handle %p) received %zd bytes.\n",
           sock_, static_cast<void*>(this), recvsize);

    // Check if we have received all the input.
    remaining_ -= recvsize;
    received_ += recvsize;
    if (remaining_ == 0) {
      const ssize_t r = handler_->receive(*this, recvbuf_);
      if (r == 0) {
        // We're not interested in reading any more data from the socket.
        // We might still have some data to send, though, so we don't close
        // the socket completely yet.
        shutdown(sock_, SHUT_RD);
      }
      if (r < 0) {
        return -1;
      }

      // Get the buffer and counters ready for the next message.
      delete[] recvbuf_;
      recvbuf_ = new uint8_t[r];
      received_ = 0;
      remaining_ = r;
    }
  }
  return 0;
}

// Flush as much of `sendbuf_` as the socket will accept.
//
// Returns 0 in two cases: output is still pending (the socket would block,
// so wait for the next EPOLLOUT), or everything was flushed but more input
// is still expected.
//
// Returns -1 (close the connection) in two cases: a send error, or
// everything was flushed and nothing remains to be received.
int EpollRecvHandlerTCP::process_write(const epoll_event *) {
  while (sendbufpos_ < sendbufsize_) {
    const size_t pending = sendbufsize_ - sendbufpos_;
    const ssize_t sent =
      send(sock_, sendbuf_ + sendbufpos_, pending, MSG_NOSIGNAL);
    if (sent >= 0) {
      sendbufpos_ += sent;
      continue;
    }

    const int err = errno;
    if (err == EAGAIN || err == EWOULDBLOCK) {
      // Kernel buffer full; epoll will tell us when we can resume.
      return 0;
    }
    printf("TCP send error. sock=%d err=%s.\n", sock_, strerror(err));
    handler_->disconnect();
    return -1;
  }

  // Everything has been flushed, so release the buffer.
  delete[] sendbuf_;
  sendbuf_ = nullptr;
  sendbufsize_ = 0;
  sendbufpos_ = 0;

  // Nothing to send and nothing more to receive: close the socket.
  return remaining_ > 0 ? 0 : -1;
}

// Listening-socket handler. It remembers the epoll instance so accepted
// connections can be registered with it, and it owns `factory`, which is
// deleted in the destructor.
EpollTcpConnectHandler::EpollTcpConnectHandler(
  const int epollfd, const int sock, BuildRecvHandlerTCP* factory
)
  : EPollHandlerInterface(sock)
  , epollfd_(epollfd)
  , factory_(factory)
{}

// Releases the owned per-connection handler factory. The base-class
// destructor then closes the listening socket.
EpollTcpConnectHandler::~EpollTcpConnectHandler() {
  delete factory_;
}

// Register the listening socket `sock` with `epollfd`. Each accepted
// connection gets its protocol handler from `factory`.
//
// The listener is made non-blocking first, because process_read() keeps
// calling accept() until it fails.
//
// Ownership:
// - On success, the new EpollTcpConnectHandler owns `sock` and `factory`.
// - If epoll registration fails, it is deleted, which deletes `factory`
//   and closes `sock`.
// - On the earlier validation or fcntl failures, neither is released and
//   the caller still owns both.
//
// Returns 0 on success, -1 on failure.
int EpollTcpConnectHandler::build(
  const int epollfd, const int sock, BuildRecvHandlerTCP* factory
) {
  if (sock < 0) {
    return -1;
  }

  if (!factory) {
    return -1;
  }

  // Make sure that the socket is non-blocking. Check F_GETFL on its own:
  // OR-ing O_NONBLOCK into its -1 error value would request every flag.
  const int flags = fcntl(sock, F_GETFL, 0);
  if (flags < 0 || fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
    printf("set non-blocking failed\n");
    return -1;
  }

  // Plain `new` reports failure by throwing, never by returning null.
  EpollTcpConnectHandler *h =
    new EpollTcpConnectHandler(epollfd, sock, factory);

  if (h->epoll_add(epollfd) < 0) {
    delete h;  // Also deletes `factory` and closes `sock`.
    return -1;
  }
  return 0;
}

int EpollTcpConnectHandler::process_read(const epoll_event *) {
  while (1) {
    sockaddr addr;
    socklen_t addr_len = sizeof(addr);
    const int s = accept(sock_, &addr, &addr_len);
    if (s < 0) {
      return 0; // No need to close the listener down.
    }

    printf("accepting TCP connection sock=%d\n", s);

    RecvHandlerTCP* handler = factory_->build(&addr, addr_len);
    if (!handler) {
      printf("factory failed.\n");
      close(s);
      continue; // No need to close the listener down.
    }

    if (EpollRecvHandlerTCP::build(epollfd_, s, handler) < 0) {
      printf("Could not register accepted socket with epoll.\n");
      close(s);
      continue; // No need to close the listener down.
    }
  }
  return 0;
}

// Write-readiness callback for the listening socket. It always returns 0,
// so EPOLLOUT notifications never cause the listener to be shut down.
int EpollTcpConnectHandler::process_write(const epoll_event *) {
  // This socket is only for listening, so this method never has to do
  // anything.
  return 0;
}

// An empty cache: no recording stored yet, so isInitialized() is false.
TCP_Cache::TCP_Cache()
  : sendbuf_(nullptr), total_send_(0), total_receive_(0) {}

// Frees the cached send buffer. It is released with free() rather than
// delete[] because the recorder allocates it with malloc/realloc.
TCP_Cache::~TCP_Cache() {
  free(sendbuf_);
}

// True once init() has stored a recording, i.e. the send buffer is non-null.
bool TCP_Cache::isInitialized() const {
  return sendbuf_ != nullptr;
}

// Store a completed recording in the cache.
//
// Called by TCP_Cache_Record::receive() once a recorded session has
// finished. Takes ownership of `sendbuf`, which must come from
// malloc/realloc because the destructor releases it with free(). A null
// `sendbuf` is ignored and leaves the cache unchanged.
//
// If the cache already holds a recording, it is replaced and its old buffer
// is freed.
void TCP_Cache::init(char* sendbuf, size_t total_send, size_t total_receive) {
  if (!sendbuf) {
    return;
  }
  // Guard against being handed our own buffer again, which must not be freed.
  if (sendbuf_ != sendbuf) {
    free(sendbuf_);
  }
  sendbuf_ = sendbuf;
  total_send_ = total_send;
  total_receive_ = total_receive;
}

// Replay the cached session on a newly accepted connection. All recorded
// outbound bytes are sent at once, and the recorded number of inbound bytes
// is returned so the caller knows how much to read.
//
// Precondition: isInitialized().
// Returns the recorded receive total, or -1 if sending failed.
ssize_t TCP_Cache::accept(SocketHandlerTCP& sock) {
  assert(isInitialized());
  // Cast explicitly so the arguments always match %lu, whatever unsigned
  // type the counters have (size_t is not unsigned long on every platform).
  printf(
    "Using cache.  send: %lu  receive: %lu\n",
    static_cast<unsigned long>(total_send_),
    static_cast<unsigned long>(total_receive_)
  );

  const ssize_t wr = sock.reply(sendbuf_, total_send_);
  if (wr < 0) {
    const int err = errno;
    printf("send failed: %s\n", strerror(err));
    return -1;
  }

  return total_receive_;
}

// Playback handler that answers connections from an already-populated
// cache.
// NOTE(review): cache_ is presumably a TCP_Cache& (so `cache` must outlive
// this object). Confirm in the header.
TCP_Cache_Playback::TCP_Cache_Playback(TCP_Cache& cache) :
  cache_(cache)
{}

// Replay the cached session on a new connection by delegating to
// TCP_Cache::accept().
ssize_t TCP_Cache_Playback::accept(SocketHandlerTCP& sock) {
  return cache_.accept(sock);
}

// Socket stand-in given to the recorded handler.
//
// Every reply is forwarded to the real socket. If forwarding succeeds, the
// bytes are also appended to the enclosing recorder's malloc'd `sendbuf_`,
// so the whole outbound stream can later be stored in the cache. A null
// `sendbuf_` means recording has been abandoned; replies are then only
// forwarded.
class TCP_Cache_Record::SocketHandlerProxy : public SocketHandlerTCP {
  // Reference to enclosing class.
  TCP_Cache_Record& parent_;

  // The object that we are proxying.
  SocketHandlerTCP& sock_;

public:
  SocketHandlerProxy(
    TCP_Cache_Record& parent, SocketHandlerTCP& sock
  ) :
    parent_(parent), sock_(sock)
  {}

  ssize_t reply(const char* buf, size_t buflen) override {
    const ssize_t wr = sock_.reply(buf, buflen);
    // Empty replies add nothing to the recording. Skipping them also avoids
    // realloc(ptr, 0), which may free the buffer and return null.
    if (wr >= 0 && parent_.sendbuf_ && buflen > 0) {
      const size_t newlen = parent_.total_send_ + buflen;
      // Keep the result separate: on failure realloc returns null but
      // leaves the original block allocated.
      char* const grown = (char*)realloc(parent_.sendbuf_, newlen);
      if (grown) {
        memcpy(grown + parent_.total_send_, buf, buflen);
        parent_.sendbuf_ = grown;
        parent_.total_send_ = newlen;
      } else {
        // Out of memory: release the partial recording instead of leaking
        // it, and stop recording.
        free(parent_.sendbuf_);
        parent_.sendbuf_ = 0;
        parent_.total_send_ = 0;
      }
    }
    return wr;
  }
};

// Start recording a session of the wrapped `handler` (owned, and deleted in
// the destructor) into `cache`.
// `sendbuf_` starts as a 1-byte malloc'd block so the proxy can grow it
// with realloc(). A null `sendbuf_` means "not recording".
TCP_Cache_Record::TCP_Cache_Record(
  RecvHandlerTCP* handler, TCP_Cache& cache
)
  : handler_(handler)
  , cache_(cache)
  , sendbuf_(static_cast<char*>(malloc(1)))
  , total_send_(0)
  , total_receive_(0)
{}

// Frees the recording buffer and deletes the wrapped handler. The buffer
// may be a partial recording; if it was already handed to the cache,
// sendbuf_ is null and free() is a no-op.
TCP_Cache_Record::~TCP_Cache_Record() {
  free(sendbuf_);
  delete handler_;
}

// Run the wrapped handler's accept() through a recording proxy. The byte
// count it asks to receive becomes the starting value of the session's
// receive total. Returns that count unchanged.
ssize_t TCP_Cache_Record::accept(SocketHandlerTCP& sock) {
  SocketHandlerProxy recorder(*this, sock);
  const ssize_t expected = handler_->accept(recorder);
  total_receive_ = expected;
  return expected;
}

// Forward a complete inbound message to the wrapped handler through a
// recording proxy, and add its next-message size to the receive total.
// When the handler reports completion (0), the recording is handed to the
// cache and this recorder's state is reset. Returns the handler's result.
ssize_t TCP_Cache_Record::receive(SocketHandlerTCP& sock, const uint8_t* buf) {
  SocketHandlerProxy recorder(*this, sock);
  const ssize_t next = handler_->receive(recorder, buf);
  total_receive_ += next;
  if (next != 0) {
    return next;
  }

  // Session finished: the cache takes over the buffer, so forget it here.
  cache_.init(sendbuf_, total_send_, total_receive_);
  sendbuf_ = 0;
  total_send_ = 0;
  total_receive_ = 0;
  return next;
}

[[ noreturn ]] void epoll_main_loop(const int epollfd) {
  while (1) {
    const size_t max_events = 10;
    epoll_event events[max_events];
    const int numevents = epoll_wait(epollfd, events, max_events, -1);
    int eventidx;
    for (eventidx = 0; eventidx < numevents; eventidx++) {
      epoll_event *ev = &events[eventidx];
      EPollHandlerInterface *handler = (EPollHandlerInterface *)ev->data.ptr;

      if (ev->events & EPOLLIN) {
        if (handler->process_read(ev) < 0) {
          printf("shutdown handler=%p\n", handler);
          delete handler;  // This also closes the file descriptor
          continue;
        }
      }

      if (ev->events & EPOLLOUT) {
        if (handler->process_write(ev) < 0) {
          printf("shutdown handler=%p\n", handler);
          delete handler;  // This also closes the file descriptor
          continue;
        }
      }

      if (ev->events & (EPOLLERR | EPOLLHUP)) {
        printf("epoll error 0x%x on handler=%p\n", ev->events, handler);
        delete handler;  // This also closes the file descriptor
        continue;
      }
    }
  }
}
